
feat: Add dryrun feature to Dynamo paths #2451

Merged: gs-olive merged 4 commits into pytorch:main on Dec 26, 2023
Conversation

gs-olive (Collaborator) commented Nov 9, 2023

Description

  • Enables building of TRT engines with a "dryrun" capability: all phases except conversion are run, and verbose logs of the graph structure and composition are printed for the user
  • Improves general-purpose debug logging by printing dryrun stats to the debug logs regardless of the option's specification
  • Provides an intuitive schematic of the graph engines, inputs, and the code path through the course of the graph
Sample Schematic
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 89 Total Operators, of which 81 operators are supported, 91.01% coverage

The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:
 torch.ops.aten.add.Tensor: 8

The following nodes are currently set to run in Torch:
Node: torch.ops.aten.add.Tensor, with layer location: /layer1/0/add
Node: torch.ops.aten.add.Tensor, with layer location: /layer1/1/add_1
Node: torch.ops.aten.add.Tensor, with layer location: /layer2/0/add_2
Node: torch.ops.aten.add.Tensor, with layer location: /layer2/1/add_3
Node: torch.ops.aten.add.Tensor, with layer location: /layer3/0/add_4
Node: torch.ops.aten.add.Tensor, with layer location: /layer3/1/add_5
Node: torch.ops.aten.add.Tensor, with layer location: /layer4/0/add_6
Node: torch.ops.aten.add.Tensor, with layer location: /layer4/1/add_7
Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner

Compiled with: CompilationSettings(precision=torch.float32, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops={'torch.ops.aten.add.Tensor'}, pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=False, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, dryrun=True)

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 224, 224)@float32]
    ...
    TRT Engine #1 - Submodule name: fused_0
     Engine Inputs: List[Tensor: (1, 512, 7, 7)@float32]
     Number of Operators in Engine: 4
     Engine Outputs: Tensor: (1, 1000)@float32
    ...
    TRT Engine #2 - Submodule name: fused_1
     Engine Inputs: List[Tensor: (1, 512, 7, 7)@float32]
     Number of Operators in Engine: 8
     Engine Outputs: Tuple(Tensor: (1, 512, 7, 7)@float32, Tensor: (1, 512, 7, 7)@float32)
    ...
    TRT Engine #3 - Submodule name: fused_2
     Engine Inputs: List[Tensor: (1, 256, 14, 14)@float32]
     Number of Operators in Engine: 11
     Engine Outputs: Tuple(Tensor: (1, 512, 7, 7)@float32, Tensor: (1, 512, 7, 7)@float32)
    ...
    TRT Engine #4 - Submodule name: fused_3
     Engine Inputs: List[Tensor: (1, 256, 14, 14)@float32]
     Number of Operators in Engine: 8
     Engine Outputs: Tuple(Tensor: (1, 256, 14, 14)@float32, Tensor: (1, 256, 14, 14)@float32)
    ...
    TRT Engine #5 - Submodule name: fused_4
     Engine Inputs: List[Tensor: (1, 128, 28, 28)@float32]
     Number of Operators in Engine: 11
     Engine Outputs: Tuple(Tensor: (1, 256, 14, 14)@float32, Tensor: (1, 256, 14, 14)@float32)
    ...
    TRT Engine #6 - Submodule name: fused_5
     Engine Inputs: List[Tensor: (1, 128, 28, 28)@float32]
     Number of Operators in Engine: 8
     Engine Outputs: Tuple(Tensor: (1, 128, 28, 28)@float32, Tensor: (1, 128, 28, 28)@float32)
    ...
    TRT Engine #7 - Submodule name: fused_6
     Engine Inputs: List[Tensor: (1, 64, 56, 56)@float32]
     Number of Operators in Engine: 11
     Engine Outputs: Tuple(Tensor: (1, 128, 28, 28)@float32, Tensor: (1, 128, 28, 28)@float32)
    ...
    TRT Engine #8 - Submodule name: fused_7
     Engine Inputs: List[Tensor: (1, 64, 56, 56)@float32]
     Number of Operators in Engine: 8
     Engine Outputs: Tuple(Tensor: (1, 64, 56, 56)@float32, Tensor: (1, 64, 56, 56)@float32)
    ...
    TRT Engine #9 - Submodule name: fused_8
     Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32]
     Number of Operators in Engine: 12
     Engine Outputs: Tuple(Tensor: (1, 64, 56, 56)@float32, Tensor: (1, 64, 56, 56)@float32)
    ...
   Outputs: List[Tensor: (1, 1000)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 9.0
   Most Operators in a TRT Engine: 12

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=12 which would generate 1 TRT engine(s)
   - For moderate graph segmentation, select min_block_size=9 which would generate 4 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=4 which generates 9 TRT engine(s)
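The min_block_size recommendations above can be approximated from the per-engine operator counts in the schematic. The sketch below is a simplified model (the real partitioner also considers graph adjacency when merging blocks), not the actual Torch-TensorRT implementation:

```python
# Operator counts of the 9 TRT engines from the schematic above
ENGINE_OP_COUNTS = [12, 4, 8, 11, 8, 11, 8, 11, 8]

def engines_at_threshold(op_counts, min_block_size):
    """Count candidate blocks that still qualify as TRT engines
    once blocks smaller than min_block_size are excluded."""
    return sum(1 for n in op_counts if n >= min_block_size)

print(engines_at_threshold(ENGINE_OP_COUNTS, 12))  # minimal segmentation: 1 engine
print(engines_at_threshold(ENGINE_OP_COUNTS, 9))   # moderate segmentation: 4 engines
print(engines_at_threshold(ENGINE_OP_COUNTS, 4))   # current setting: 9 engines
```

Under this model, raising min_block_size trades fewer, larger TRT engines against more operators falling back to Torch.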

Fixes #2081
Fixes #2413

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • [ x ] My code follows the style guidelines of this project (You can use the linters)
  • [ x ] I have performed a self-review of my own code
  • [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ x ] I have made corresponding changes to the documentation
  • [ - ] I have added tests to verify my fix or my feature
    • Validated via CI
  • [ x ] New and existing unit tests pass locally with my changes
  • [ x ] I have added the relevant labels to my PR so that relevant reviewers are notified

@gs-olive gs-olive self-assigned this Nov 9, 2023
@github-actions github-actions bot added component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Nov 9, 2023
@gs-olive gs-olive added the WIP Work is in progress, pull request should not be merged yet label Nov 9, 2023
narendasan (Collaborator):

Can you add a summary of user settings? just dump the struct?

narendasan (Collaborator):

Also a list of ops to be run in PyTorch

narendasan (Collaborator):

Also, this (if it doesn't add to compilation time) should be added as debug logging to all compilation calls. And since output at INFO level might get masked depending on the user's logging settings, I think if people are explicitly calling dry run it should be printed to STDOUT.

@gs-olive gs-olive removed the WIP Work is in progress, pull request should not be merged yet label Nov 10, 2023
@gs-olive gs-olive force-pushed the dryrun_mode branch 2 times, most recently from 0837944 to ffbef2c on November 14, 2023
@github-actions github-actions bot added the component: tests Issues re: Tests label Nov 14, 2023
- Enables building of TRT engines with "dryrun" capabilities, meaning
all of the phases except conversion are run and verbose logs of the
graph structure and composition are printed for the user
- Improves general-purpose debug logging by printing dryrun stats to the
debug logs regardless of option specification
- Provides intuitive schematic of the graph engines, inputs, and code
path through the course of the graph
- Add detailed layer information for excluded ops
narendasan (Collaborator) left a comment:

LGTM, notes for later:

We should pretty print this Compiled with: CompilationSettings(precision=torch.float32, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops={'torch.ops.aten.add.Tensor'}, pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=False, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, dryrun=True)

Also are we able to see the intermediate torch graphs just like we see the tensorrt ones?
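The pretty-printing suggestion above could be sketched as one "key=value" line per field instead of a single repr() blob. `Settings` here is a hypothetical stand-in with a few of CompilationSettings' fields, not the real class:

```python
from dataclasses import dataclass, fields

@dataclass
class Settings:
    # Hypothetical stand-in for CompilationSettings (subset of fields)
    precision: str = "torch.float32"
    debug: bool = True
    min_block_size: int = 1
    dryrun: bool = True

def pretty_settings(settings) -> str:
    # Emit one indented "name=value" line per dataclass field
    body = "\n".join(
        f"  {f.name}={getattr(settings, f.name)!r}" for f in fields(settings)
    )
    return "Compiled with:\n" + body

print(pretty_settings(Settings()))
```

Because it iterates `fields()`, the formatter stays in sync as settings are added or removed.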

gs-olive (Collaborator, Author) commented Dec 18, 2023

@narendasan - thanks for the comments. I have added an issue for the suggested improvements here: #2548. We are not able to see inputs to intermediate Torch graphs yet, because they are not always packaged as modules. Specifically, our global partitioner, which uses Torch partitioning utilities, only packages the TRT engines as modules and leaves the Torch operators alone, which makes it difficult to group those into subgraphs. This improvement can be added for the fast partitioner, however, and this is noted in the new feature request issue.

narendasan (Collaborator) commented Dec 18, 2023


Can you not iterate across the graph and just make lists of non-TensorRT ops? Even just an idea of what is left out is probably some of the most important information this feature could produce; since the goal is for all ops to be in a TRT engine, the high-order bit is what is left out.
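The op-listing idea above can be sketched with a flat pass over node targets. The supported-op set and target strings here are hypothetical; real code would walk `gm.graph.nodes` on a torch.fx GraphModule and test each call target against the converter registry:

```python
from collections import Counter

# Hypothetical supported-op set and flattened node targets for illustration
SUPPORTED_OPS = {"aten.convolution.default", "aten.relu.default"}
node_targets = [
    "aten.convolution.default",
    "aten.add.Tensor",
    "aten.relu.default",
    "aten.add.Tensor",
]

# Tally every op that would not land in a TRT engine, with its count
unsupported = Counter(t for t in node_targets if t not in SUPPORTED_OPS)
for op, count in unsupported.items():
    print(f"{op}: {count}")
```

This is exactly the shape of the "unsupported or excluded" op-count listing in the dry-run output above.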

@@ -295,8 +332,19 @@ def compile_module(
submodule = getattr(partitioned_module, name)
# Criteria for a module to be convertible to TRT
if settings.use_fast_partitioner and "_run_on_acc" not in name:
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
gs-olive (Collaborator, Author) commented on the diff:
  • In global_partitioning, named_children only returns TRT modules
  • Can add feature for fast_partitioning first, then expand to global later?

@@ -341,4 +429,6 @@ def compile_module(
if fast_partitioner_failed:
settings.use_fast_partitioner = True

dryrun_stats_display(dryrun_tracker, settings.dryrun)
gs-olive (Collaborator, Author) commented on the diff:

Pass in graph including the Torch module/node information for proper display formatting

@gs-olive gs-olive merged commit 15082c4 into pytorch:main Dec 26, 2023
19 checks passed
@gs-olive gs-olive deleted the dryrun_mode branch December 26, 2023 22:58
gs-olive added a commit that referenced this pull request Jan 3, 2024
- Excluded all changes to `docs` and `.github` directories; did include
documentation changes and all other commits, with the exception of #2451
and #2445 for reasons discussed
- Made necessary changes to switch over to Torch 2.2.0 rc builds,
including updating imports
gs-olive added three further commits referencing this pull request on Jan 4, 2024, each with the same commit message.
Labels
cla signed component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests
Successfully merging this pull request may close these issues.

✨[Feature] Dry-Run Functionality for Dynamo
✨[Feature] Compilation Dry-Run in Dynamo
3 participants